{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "0894a735",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig\n",
    "import torch\n",
    "\n",
    "TORCH_DTYPE = 'bfloat16'\n",
    "nf4_config = BitsAndBytesConfig(\n",
    "    load_in_4bit=True,\n",
    "    bnb_4bit_quant_type='nf4',\n",
    "    bnb_4bit_use_double_quant=True,\n",
    "    bnb_4bit_compute_dtype=getattr(torch, TORCH_DTYPE)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "6daae66f",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = AutoTokenizer.from_pretrained('mesolitica/malaysian-llama2-7b-32k-instructions')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "0da5df04",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b2cde6c53efe4c57866bf7c45e1630c1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    'mesolitica/malaysian-llama2-7b-32k-instructions',\n",
    "    use_flash_attention_2 = True,\n",
    "    quantization_config = nf4_config\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "f18b9442",
   "metadata": {},
   "outputs": [],
   "source": [
    "def parse_llama_chat(messages):\n",
    "\n",
    "    system = messages[0]['content']\n",
    "    user_query = messages[-1]['content']\n",
    "\n",
    "    users, assistants = [], []\n",
    "    for q in messages[1:-1]:\n",
    "        if q['role'] == 'user':\n",
    "            users.append(q['content'])\n",
    "        elif q['role'] == 'assistant':\n",
    "            assistants.append(q['content'])\n",
    "\n",
    "    texts = [f'<s>[INST] <<SYS>>\\n{system}\\n<</SYS>>\\n\\n']\n",
    "    for u, a in zip(users, assistants):\n",
    "        texts.append(f'{u.strip()} [/INST] {a.strip()} </s><s>[INST] ')\n",
    "    texts.append(f'{user_query.strip()} [/INST]')\n",
    "    prompt = ''.join(texts).strip()\n",
    "    return prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "07881cfc",
   "metadata": {},
   "outputs": [],
   "source": [
    "messages = [\n",
    "    {'role': 'system', 'content': 'awak adalah AI yang mampu jawab segala soalan'},\n",
    "    {'role': 'user', 'content': 'kwsp tu apa'}\n",
    "]\n",
    "prompt = parse_llama_chat(messages)\n",
    "inputs = tokenizer([prompt], return_tensors='pt', add_special_tokens=False).to('cuda')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "67d3c231",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'<s> [INST] <<SYS>>\\nawak adalah AI yang mampu jawab segala soalan\\n<</SYS>>\\n\\nkwsp tu apa [/INST] Kumpulan Wang Simpanan Pekerja (KWSP) ialah organisasi yang ditubuhkan untuk menggalakkan simpanan dan pelaburan dalam kalangan rakyat Malaysia untuk persaraan mereka. KWSP menyediakan perkhidmatan simpanan dan pelaburan kepada ahlinya, yang termasuk semua pekerja yang bekerja di Malaysia, dan menawarkan beberapa faedah kepada mereka, seperti dividen tahunan. </s>'"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "generate_kwargs = dict(\n",
    "    inputs,\n",
    "    max_new_tokens=1024,\n",
    "    top_p=0.95,\n",
    "    top_k=50,\n",
    "    temperature=0.9,\n",
    "    do_sample=True,\n",
    "    num_beams=1,\n",
    ")\n",
    "r = model.generate(**generate_kwargs)\n",
    "tokenizer.decode(r[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "4efaf03f",
   "metadata": {},
   "outputs": [],
   "source": [
    "messages = [\n",
    "    {'role': 'system', 'content': 'awak adalah AI yang mampu jawab segala soalan'},\n",
    "    {'role': 'user', 'content': 'awat malaysia ada jabatan koko, malaysia bukan buat keluaq koko banyak pun'}\n",
    "]\n",
    "prompt = parse_llama_chat(messages)\n",
    "inputs = tokenizer([prompt], return_tensors='pt', add_special_tokens=False).to('cuda')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "bdc41dda",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<s> [INST] <<SYS>>\n",
      "awak adalah AI yang mampu jawab segala soalan\n",
      "<</SYS>>\n",
      "\n",
      "awat malaysia ada jabatan koko, malaysia bukan buat keluaq koko banyak pun [/INST] Penanaman koko di Malaysia bermula pada tahun 1950-an, dengan pengeluaran koko kering pertama dihasilkan pada tahun 1958. Pada masa itu, negara ini menjadi pembekal koko terbesar di dunia, dan industri ini terus berkembang dan mewujudkan ribuan pekerjaan dalam pemprosesan dan pengeluaran koko.\n",
      "\n",
      "Walau bagaimanapun, penurunan harga koko di pasaran antarabangsa pada tahun 1980-an menyebabkan pengeluaran koko Malaysia menurun, dan negara ini telah mengubah tumpuan daripada penanaman koko kepada komoditi lain seperti minyak kelapa sawit dan getah.\n",
      "\n",
      "Walaupun Malaysia bukan pengeluar koko yang besar, industri koko tempatan masih penting dari segi ekonomi dan sosial. Koko ialah komoditi yang diperdagangkan di bursa tempatan, dan pengeluar koko tempatan menghasilkan koko kering yang dieksport ke seluruh dunia.\n",
      "\n",
      "Jabatan Koko Malaysia ialah sebuah agensi kerajaan yang bertanggungjawab untuk membangunkan industri koko dan mempromosikan pengeluaran koko. Agensi ini bekerjasama dengan industri untuk meningkatkan produktiviti dan memastikan kualiti produk koko Malaysia. Ia juga menggalakkan usaha pembangunan pekebun kecil untuk mempromosikan industri koko dan mewujudkan peluang pekerjaan dan ekonomi di kawasan pedalaman. </s>\n"
     ]
    }
   ],
   "source": [
    "generate_kwargs = dict(\n",
    "    inputs,\n",
    "    max_new_tokens=1024,\n",
    "    top_p=0.95,\n",
    "    top_k=50,\n",
    "    temperature=0.9,\n",
    "    do_sample=True,\n",
    "    num_beams=1,\n",
    ")\n",
    "r = model.generate(**generate_kwargs)\n",
    "print(tokenizer.decode(r[0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "05afbe68",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
